import torch
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import pytorch_lightning as pl
import pandas as pd

class PredRNNAWSDataset(Dataset):
    """Dataset yielding one PWAFS array plus two hourly ``Prec`` frame sequences per sample.

    Every row of the CSV file provides the start time of one sample.  From that
    start time the dataset derives the relative file names of 19 hourly ``Prec``
    frames (``file_prec_out``), of the first 6 of those frames (``file_prec_in``)
    and of a single PWAFS file (``file_pwarfs``).
    """

    def __init__(self, csv_file, root_dir):
        """
        :param csv_file: headerless CSV file whose first column holds the start datetime of each sample
        :param root_dir: dataset's parent folder; ``Prec`` and ``PWAFS`` are looked up underneath it
        """
        super(PredRNNAWSDataset, self).__init__()
        self.fname = pd.read_csv(csv_file, header=None, parse_dates=[0, ])
        # One row per sample, one column per hourly frame (relative path below "<root_dir>/Prec").
        self.file_prec_out = pd.DataFrame([pd.date_range(start=itime, periods=19,
                                                         freq="h").strftime("%Y/%m/%Y%m%d%H%M.npy") for itime in
                                           self.fname.iloc[:, 0]])
        self.file_prec_in = pd.DataFrame([pd.date_range(start=itime, periods=6,
                                                        freq="h").strftime("%Y/%m/%Y%m%d%H%M.npy") for itime in
                                          self.fname.iloc[:, 0]])
        # One relative path per sample below "<root_dir>/PWAFS".
        self.file_pwarfs = self.fname.iloc[:, 0].dt.strftime("%Y/%m/PWAFS_%Y%m%d%H.npy")
        self.root_dir = root_dir

    def __getitem__(self, item):
        """Load the arrays of sample ``item``.

        :param item: integer position of the sample (a tensor index is converted with ``tolist()``)
        :return: tuple ``(pwafs, prec_in, prec_out)`` where ``pwafs`` is the PWAFS array as stored on disk,
                 ``prec_in`` is a float32 array with one frame per column of ``file_prec_in`` and ``prec_out``
                 a float32 array with one frame per column of ``file_prec_out``; both frame stacks get a
                 singleton axis inserted at position 1
        """
        if torch.is_tensor(item):
            item = item.tolist()
        prec_dir = os.path.join(self.root_dir, "Prec")
        out_names = list(self.file_prec_out.iloc[item])
        in_names = list(self.file_prec_in.iloc[item])
        # The input window uses the same start time and step as the output window, so its
        # files are also part of the output window: read every distinct file only once.
        frames = {name: np.load(os.path.join(prec_dir, name))
                  for name in dict.fromkeys(out_names + in_names)}
        data_outs = np.asarray([frames[name] for name in out_names], dtype=np.float32)
        data_pwafs = np.load(os.path.join(self.root_dir, "PWAFS", self.file_pwarfs.iloc[item]))
        # Built as a separate array so the input and output stacks never share memory.
        data_prec = np.asarray([frames[name] for name in in_names], dtype=np.float32)
        return data_pwafs[:, :, ...], data_prec[:, np.newaxis, ...], data_outs[:, np.newaxis, ...]

    def __len__(self):
        """Return the number of samples, i.e. the number of rows in the CSV file."""
        return self.fname.shape[0]


class PredRNNAWSDataModule(pl.LightningDataModule):
    """Lightning data module that serves train/val/test ``PredRNNAWSDataset`` splits.

    Note that ``root_dir`` comes *before* ``test_file`` in this constructor's
    positional order.
    """

    def __init__(self, train_file, val_file, root_dir, test_file,
                 num_workers=8, batch_size=10, pin_memory=True, **kwargs):
        """
        :param train_file: CSV file handed to the training ``PredRNNAWSDataset``
        :param val_file: CSV file handed to the validation ``PredRNNAWSDataset``
        :param root_dir: data folder handed to every ``PredRNNAWSDataset``
        :param test_file: CSV file handed to the test ``PredRNNAWSDataset``
        :param num_workers: number of ``DataLoader`` worker processes
        :param batch_size: number of samples per batch
        :param pin_memory: forwarded to ``DataLoader(pin_memory=...)``
        :param kwargs: accepted and ignored
        """
        super().__init__()
        self.train_file = train_file
        self.val_file = val_file
        self.test_file = test_file
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.root_dir = root_dir

    def setup(self, stage=None):
        """Instantiate the three datasets; ``stage`` is accepted but not inspected."""
        self.train = PredRNNAWSDataset(self.train_file, self.root_dir)
        self.val = PredRNNAWSDataset(self.val_file, self.root_dir)
        self.test = PredRNNAWSDataset(self.test_file, self.root_dir)

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a ``DataLoader`` configured from this module's settings."""
        return DataLoader(dataset, num_workers=self.num_workers, shuffle=shuffle, batch_size=self.batch_size,
                          pin_memory=self.pin_memory,
                          # DataLoader raises ValueError for persistent_workers=True when
                          # num_workers == 0, so only ask for it when workers exist.
                          persistent_workers=self.num_workers > 0)

    def train_dataloader(self):
        """Return the shuffled loader over the training dataset."""
        return self._make_loader(self.train, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation dataset."""
        return self._make_loader(self.val, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled loader over the test dataset."""
        return self._make_loader(self.test, shuffle=False)

class UNetAWSDataset(Dataset):
    """Dataset yielding one PWAFS array and 19 hourly ``Prec`` frames per sample."""

    def __init__(self, csv_file, root_dir):
        """
        :param csv_file: headerless CSV file of an already split train/val/test set; its first column
                         holds the start datetime of each sample
        :param root_dir: folder where the data is stored; ``Prec`` and ``PWAFS`` are looked up underneath it
        """
        super().__init__()
        self.root_dir = root_dir
        self.fname = pd.read_csv(csv_file, header=None, parse_dates=[0])
        start_times = self.fname.iloc[:, 0]
        # One row per sample, one column per hourly frame starting at the sample's start time.
        name_rows = []
        for start in start_times:
            hourly = pd.date_range(start=start, periods=19, freq="h")
            name_rows.append(hourly.strftime("%Y/%m/%Y%m%d%H%M.npy"))
        self.file_prec_out = pd.DataFrame(name_rows)
        self.file_pwarfs = start_times.dt.strftime("%Y/%m/PWAFS_%Y%m%d%H.npy")

    def __getitem__(self, item):
        """Load sample ``item``.

        :param item: integer position of the sample (a tensor index is converted with ``tolist()``)
        :return: tuple ``(pwafs, prec_out)``; ``pwafs`` is the PWAFS array as stored on disk and
                 ``prec_out`` the float32 stack of frames with a singleton axis inserted at position 1
        """
        index = item.tolist() if torch.is_tensor(item) else item
        frames = []
        for relative_name in self.file_prec_out.iloc[index]:
            frames.append(np.load(os.path.join(self.root_dir, "Prec", relative_name)))
        data_outs = np.asarray(frames, dtype=np.float32)
        pwafs_path = os.path.join(self.root_dir, "PWAFS", self.file_pwarfs.iloc[index])
        data_pwafs = np.load(pwafs_path)
        return data_pwafs[:, :, ...], data_outs[:, np.newaxis, ...]

    def __len__(self):
        """Return the number of samples, i.e. the number of rows in the CSV file."""
        n_samples, _ = self.fname.shape
        return n_samples


class UNetAWSDataModule(pl.LightningDataModule):
    """Lightning data module that serves train/val/test ``UNetAWSDataset`` splits."""

    def __init__(self, train_file, val_file, test_file, root_dir,
                 num_workers=8, batch_size=10, pin_memory=True, **kwargs):
        """
        :param train_file: CSV file handed to the training ``UNetAWSDataset``
        :param val_file: CSV file handed to the validation ``UNetAWSDataset``
        :param test_file: CSV file handed to the test ``UNetAWSDataset``
        :param root_dir: data folder handed to every ``UNetAWSDataset``
        :param num_workers: number of ``DataLoader`` worker processes
        :param batch_size: number of samples per batch
        :param pin_memory: forwarded to ``DataLoader(pin_memory=...)``
        :param kwargs: accepted and ignored
        """
        super().__init__()
        self.train_file = train_file
        self.val_file = val_file
        self.test_file = test_file
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.root_dir = root_dir

    def setup(self, stage=None):
        """Instantiate the three datasets; ``stage`` is accepted but not inspected."""
        self.train = UNetAWSDataset(self.train_file, self.root_dir)
        self.val = UNetAWSDataset(self.val_file, self.root_dir)
        self.test = UNetAWSDataset(self.test_file, self.root_dir)

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a ``DataLoader`` configured from this module's settings."""
        return DataLoader(dataset, num_workers=self.num_workers, shuffle=shuffle, batch_size=self.batch_size,
                          pin_memory=self.pin_memory,
                          # DataLoader raises ValueError for persistent_workers=True when
                          # num_workers == 0, so only ask for it when workers exist.
                          persistent_workers=self.num_workers > 0)

    def train_dataloader(self):
        """Return the shuffled loader over the training dataset."""
        return self._make_loader(self.train, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation dataset."""
        return self._make_loader(self.val, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled loader over the test dataset."""
        return self._make_loader(self.test, shuffle=False)

class UNetDataset(Dataset):
    """Dataset yielding one ``PWAFSNormF32`` array and the last 13 of 19 hourly ``Prec32`` frames per sample."""

    # Number of leading hourly frames of each 19-frame window that are not returned.
    _SKIPPED_FRAMES = 6

    def __init__(self, csv_file, root_dir):
        """
        :param csv_file: headerless CSV file of an already split train/val/test set; its first column
                         holds the start datetime of each sample
        :param root_dir: folder where the data is stored; ``Prec32`` and ``PWAFSNormF32`` are looked up
                         underneath it
        """
        super().__init__()
        self.fname = pd.read_csv(csv_file, header=None, parse_dates=[0, ])
        # One row per sample, one column per hourly frame starting at the sample's start time.
        self.file_prec_out = pd.DataFrame([pd.date_range(start=itime, periods=19,
                                                         freq="h").strftime("%Y/%m/%Y%m%d%H%M.npy") for itime in
                                           self.fname.iloc[:, 0]])
        self.file_pwarfs = self.fname.iloc[:, 0].dt.strftime("%Y/%m/PWAFS_%Y%m%d%H.npy")
        self.root_dir = root_dir

    def __getitem__(self, item):
        """Load sample ``item``.

        :param item: integer position of the sample (a tensor index is converted with ``tolist()``)
        :return: tuple ``(pwafs, prec_out)``; ``pwafs`` is the array as stored on disk and ``prec_out``
                 the float32 stack of frames 6..18 with a singleton axis inserted at position 1
        """
        if torch.is_tensor(item):
            item = item.tolist()
        # Only frames from index 6 onwards are returned, so the earlier files are not read at all.
        wanted = list(self.file_prec_out.iloc[item])[self._SKIPPED_FRAMES:]
        data_outs = np.asarray([np.load(os.path.join(self.root_dir, "Prec32", ifilename)) for
                                ifilename in wanted], dtype=np.float32)
        data_pwafs = np.load(os.path.join(self.root_dir, "PWAFSNormF32", self.file_pwarfs.iloc[item]))

        return data_pwafs[:, :, ...], data_outs[:, np.newaxis, ...]

    def __len__(self):
        """Return the number of samples, i.e. the number of rows in the CSV file."""
        return self.fname.shape[0]


class UNetDataModule(pl.LightningDataModule):
    """Lightning data module that serves train/val/test ``UNetDataset`` splits."""

    def __init__(self, train_file,
                 val_file,
                 test_file,
                 root_dir,
                 num_workers=8, batch_size=10, pin_memory=True, **kwargs):
        """
        :param train_file: CSV file handed to the training ``UNetDataset``
        :param val_file: CSV file handed to the validation ``UNetDataset``
        :param test_file: CSV file handed to the test ``UNetDataset``
        :param root_dir: data folder handed to every ``UNetDataset``
        :param num_workers: number of ``DataLoader`` worker processes
        :param batch_size: number of samples per batch
        :param pin_memory: forwarded to ``DataLoader(pin_memory=...)``
        :param kwargs: accepted and ignored
        """
        super().__init__()
        self.train_file = train_file
        self.val_file = val_file
        self.test_file = test_file
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.root_dir = root_dir

    def setup(self, stage=None):
        """Instantiate the three datasets; ``stage`` is accepted but not inspected."""
        self.train = UNetDataset(self.train_file, self.root_dir)
        self.val = UNetDataset(self.val_file, self.root_dir)
        self.test = UNetDataset(self.test_file, self.root_dir)

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a ``DataLoader`` configured from this module's settings."""
        return DataLoader(dataset, num_workers=self.num_workers, shuffle=shuffle, batch_size=self.batch_size,
                          pin_memory=self.pin_memory,
                          # DataLoader raises ValueError for persistent_workers=True when
                          # num_workers == 0, so only ask for it when workers exist.
                          persistent_workers=self.num_workers > 0)

    def train_dataloader(self):
        """Return the shuffled loader over the training dataset."""
        return self._make_loader(self.train, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation dataset."""
        return self._make_loader(self.val, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled loader over the test dataset."""
        return self._make_loader(self.test, shuffle=False)


class PredRNNDataset(Dataset):
    """Dataset yielding one ``PWAFSNormF32`` array and 19 hourly ``Prec32`` frames per sample."""

    def __init__(self, csv_file, root_dir):
        """
        :param csv_file: headerless CSV file of an already split train/val/test set; its first column
                         holds the start datetime of each sample
        :param root_dir: folder where the data is stored; ``Prec32`` and ``PWAFSNormF32`` are looked up
                         underneath it
        """
        super().__init__()
        self.fname = pd.read_csv(csv_file, header=None, parse_dates=[0])
        sample_starts = self.fname.iloc[:, 0]

        def hourly_names(start):
            # Relative paths of the 19 hourly frames beginning at ``start``.
            stamps = pd.date_range(start=start, periods=19, freq="h")
            return stamps.strftime("%Y/%m/%Y%m%d%H%M.npy")

        self.file_prec_out = pd.DataFrame([hourly_names(start) for start in sample_starts])
        self.file_pwarfs = sample_starts.dt.strftime("%Y/%m/PWAFS_%Y%m%d%H.npy")
        self.root_dir = root_dir

    def __getitem__(self, item):
        """Load sample ``item``.

        :param item: integer position of the sample (a tensor index is converted with ``tolist()``)
        :return: tuple ``(pwafs, prec_out)``; ``pwafs`` is the array as stored on disk and ``prec_out``
                 the float32 stack of frames with a singleton axis inserted at position 1
        """
        position = item.tolist() if torch.is_tensor(item) else item
        loaded = []
        for frame_name in self.file_prec_out.iloc[position]:
            frame_path = os.path.join(self.root_dir, "Prec32", frame_name)
            loaded.append(np.load(frame_path))
        data_outs = np.asarray(loaded, dtype=np.float32)
        pwafs_name = self.file_pwarfs.iloc[position]
        data_pwafs = np.load(os.path.join(self.root_dir, "PWAFSNormF32", pwafs_name))
        return data_pwafs[:, :, ...], data_outs[:, np.newaxis, ...]

    def __len__(self):
        """Return the number of samples, i.e. the number of rows in the CSV file."""
        row_count = self.fname.shape[0]
        return row_count


class PredRNNDataModule(pl.LightningDataModule):
    """Lightning data module that serves train/val/test ``PredRNNDataset`` splits."""

    def __init__(self, train_file,
                       val_file,
                       test_file,
                       root_dir,
                       num_workers=8, batch_size=10, pin_memory=True, **kwargs):
        """
        :param train_file: CSV file handed to the training ``PredRNNDataset``
        :param val_file: CSV file handed to the validation ``PredRNNDataset``
        :param test_file: CSV file handed to the test ``PredRNNDataset``
        :param root_dir: data folder handed to every ``PredRNNDataset``
        :param num_workers: number of ``DataLoader`` worker processes
        :param batch_size: number of samples per batch
        :param pin_memory: forwarded to ``DataLoader(pin_memory=...)``
        :param kwargs: accepted and ignored
        """
        super().__init__()
        self.train_file = train_file
        self.val_file = val_file
        self.test_file = test_file
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.root_dir = root_dir

    def setup(self, stage=None):
        """Instantiate the three datasets; ``stage`` is accepted but not inspected."""
        self.train = PredRNNDataset(self.train_file, self.root_dir)
        self.val = PredRNNDataset(self.val_file, self.root_dir)
        self.test = PredRNNDataset(self.test_file, self.root_dir)

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a ``DataLoader`` configured from this module's settings."""
        return DataLoader(dataset, num_workers=self.num_workers, shuffle=shuffle, batch_size=self.batch_size,
                          pin_memory=self.pin_memory,
                          # DataLoader raises ValueError for persistent_workers=True when
                          # num_workers == 0, so only ask for it when workers exist.
                          persistent_workers=self.num_workers > 0)

    def train_dataloader(self):
        """Return the shuffled loader over the training dataset."""
        return self._make_loader(self.train, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation dataset."""
        return self._make_loader(self.val, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled loader over the test dataset."""
        return self._make_loader(self.test, shuffle=False)